# -*- coding:utf-8 -*-

import os
import math
import numpy as np
import tensorflow as tf
from PIL import Image
import time


# VGG 自带的一个常量，之前VGG训练通过归一化，所以现在同样需要作此操作
VGG_MEAN = [103.939, 116.779, 123.68] # rgb 三通道的均值

class VGGNet():
    '''
    创建 vgg16 网络 结构
    从模型中载入参数
    '''
    def __init__(self, data_dict):
        '''
        传入vgg16模型
        :param data_dict: vgg16.npy (字典类型)
        '''
        self.data_dict = data_dict


    def get_conv_filter(self, name):
        '''
        得到对应名称的卷积层
        :param name: 卷积层名称
        :return: 该卷积层输出
        '''
        return tf.constant(self.data_dict[name][0], name = 'conv')

    def get_fc_weight(self, name):
        '''
        获得名字为name的全连接层权重
        :param name: 连接层名称
        :return: 该层权重
        '''
        return tf.constant(self.data_dict[name][0], name = 'fc')

    def get_bias(self, name):
        '''
        获得名字为name的全连接层偏置
        :param name: 连接层名称
        :return: 该层偏置
        '''
        return tf.constant(self.data_dict[name][1], name = 'bias')


    def conv_layer(self, x, name):
        '''
        创建一个卷积层
        :param x:
        :param name:
        :return:
        '''
        # 在写计算图模型的时候，加一些必要的 name_scope，这是一个比较好的编程规范
        # 可以防止命名冲突， 二可视化计算图的时候比较清楚
        with tf.name_scope(name):
            # 获得 w 和 b
            conv_w = self.get_conv_filter(name)
            conv_b = self.get_bias(name)

            # 进行卷积计算
            h = tf.nn.conv2d(x, conv_w, strides = [1, 1, 1, 1], padding = 'SAME')
            '''
            因为此刻的 w 和 b 是从外部传递进来，所以使用 tf.nn.conv2d()
            tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu = None, name = None) 参数说明：
            input 输入的tensor， 格式[batch, height, width, channel]
            filter 卷积核 [filter_height, filter_width, in_channels, out_channels] 
                分别是：卷积核高，卷积核宽，输入通道数，输出通道数
            strides 步长 卷积时在图像每一维度的步长，长度为4
            padding 参数可选择 “SAME” “VALID”
            
            '''
            # 加上偏置
            h = tf.nn.bias_add(h, conv_b)
            # 使用激活函数
            h = tf.nn.relu(h)
            return h


    def pooling_layer(self, x, name):
        '''
        创建池化层
        :param x: 输入的tensor
        :param name: 池化层名称
        :return: tensor
        '''
        return tf.nn.max_pool(x,
                              ksize = [1, 2, 2, 1], # 核参数， 注意：都是4维
                              strides = [1, 2, 2, 1],
                              padding = 'SAME',
                              name = name
                              )

    def fc_layer(self, x, name, activation = tf.nn.relu):
        '''
        创建全连接层
        :param x: 输入tensor
        :param name: 全连接层名称
        :param activation: 激活函数名称
        :return: 输出tensor
        '''
        with tf.name_scope(name, activation):
            # 获取全连接层的 w 和 b
            fc_w = self.get_fc_weight(name)
            fc_b = self.get_bias(name)
            # 矩阵相乘 计算
            h = tf.matmul(x, fc_w)
            #　添加偏置
            h = tf.nn.bias_add(h, fc_b)
            # 因为最后一层是没有激活函数relu的，所以在此要做出判断
            if activation is None:
                return h
            else:
                return activation(h)

    def flatten_layer(self, x, name):
        '''
        展平
        :param x: input_tensor
        :param name:
        :return: 二维矩阵
        '''
        with tf.name_scope(name):
            # [batch_size, image_width, image_height, channel]
            x_shape = x.get_shape().as_list()
            # 计算后三维合并后的大小
            dim = 1
            for d in x_shape[1:]:
                dim *= d
            # 形成一个二维矩阵
            x = tf.reshape(x, [-1, dim])
            return x

    def build(self, x_rgb):
        '''
        创建vgg16 网络
        :param x_rgb: [1, 224, 224, 3]
        :return:
        '''
        start_time = time.time()
        print('模型开始创建……')
        # 将输入图像进行处理，将每个通道减去均值
        r, g, b = tf.split(x_rgb, [1, 1, 1], axis = 3)
        '''
        tf.split(value, num_or_size_split, axis=0)用法：
        value:输入的Tensor
        num_or_size_split:有两种用法：
            1.直接传入一个整数，代表会被切成几个张量，切割的维度有axis指定
            2.传入一个向量，向量长度就是被切的份数。传入向量的好处在于，可以指定每一份有多少元素
        axis, 指定从哪一个维度切割
        因此，上一句的意思就是从第4维切分，分为3份，每一份只有1个元素
        '''
        # 将 处理后的通道再次合并起来
        x_bgr = tf.concat([b - VGG_MEAN[0], g - VGG_MEAN[1], r - VGG_MEAN[2]], axis = 3)

#        assert x_bgr.get_shape().as_list()[1:] == [224, 224, 3]

        # 开始构建卷积层
        # vgg16 的网络结构
        # 第一层：2个卷积层 1个pooling层
        # 第二层：2个卷积层 1个pooling层
        # 第三层：3个卷积层 1个pooling层
        # 第四层：3个卷积层 1个pooling层
        # 第五层：3个卷积层 1个pooling层
        # 第六层： 全连接
        # 第七层： 全连接
        # 第八层： 全连接

        # 这些变量名称不能乱取，必须要和vgg16模型保持一致
        # 另外，将这些卷积层用self.的形式，方便以后取用方便
        self.conv1_1 = self.conv_layer(x_bgr, 'conv1_1')
        self.conv1_2 = self.conv_layer(self.conv1_1, 'conv1_2')
        self.pool1 = self.pooling_layer(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, 'conv2_1')
        self.conv2_2 = self.conv_layer(self.conv2_1, 'conv2_2')
        self.pool2 = self.pooling_layer(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, 'conv3_1')
        self.conv3_2 = self.conv_layer(self.conv3_1, 'conv3_2')
        self.conv3_3 = self.conv_layer(self.conv3_2, 'conv3_3')
        self.pool3 = self.pooling_layer(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, 'conv4_1')
        self.conv4_2 = self.conv_layer(self.conv4_1, 'conv4_2')
        self.conv4_3 = self.conv_layer(self.conv4_2, 'conv4_3')
        self.pool4 = self.pooling_layer(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, 'conv5_1')
        self.conv5_2 = self.conv_layer(self.conv5_1, 'conv5_2')
        self.conv5_3 = self.conv_layer(self.conv5_2, 'conv5_3')
        self.pool5 = self.pooling_layer(self.conv5_3, 'pool5')

        ''' 因为风格转换只需要 卷积层  的数据
        self.flatten5 = self.flatten_layer(self.pool5, 'flatten')
        self.fc6 = self.fc_layer(self.flatten5, 'fc6')
        self.fc7 = self.fc_layer(self.fc6, 'fc7')
        self.fc8 = self.fc_layer(self.fc7, 'fc8', activation = None)
        self.prob = tf.nn.softmax(self.fc8, name = 'prob')
        '''


        print('创建模型结束：%4ds' % (time.time() - start_time))

# 指定 model 路径
vgg16_npy_pyth = './vgg16.npy'
# 内容图像 路径
content_img_path = './shanghai.jpg'
# 风格图像路径
style_img_path = './candy.jpg'

# 训练的步数
num_steps = 1000
# 指定学习率
learning_rate = 10

# 设置 两个 参数
lambda_c = 0.1
lambda_s = 5000

# 输入 目录
output_dir = './run_style_transfer'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)


def initial_result(shape, mean, stddev):
    '''
    定义一个初始化好的随机图片，然后在该图片上不停的梯度下降来得到效果。
    :param shape: 输入形状
    :param mean: 均值
    :param stddev: 方法
    :return: 图片
    '''
    initial = tf.truncated_normal(shape, mean = mean, stddev = stddev) # 一个截断的正态分布
    '''
    tf.truncated_normal(shape, mean, stddev) 生成截断的生态分布函数
    如果产生的正态分布值和均值差值大于二倍的标准差，那就重新生成。
    '''
    return tf.Variable(initial)

def read_img(img_name):
    '''
    读取图片
    :param img_name: 图片路径
    :return: 4维矩阵
    '''
    img = Image.open(img_name)
    # 图像为三通道（224， 244， 3），但是需要转化为4维
    np_img = np.array(img) # 224, 224, 3
    np_img = np.asarray([np_img], dtype = np.int32) # 将生成的列表转化为 array (1, 224, 224, 3)
    return np_img

def gram_matrix(x):
    '''
    计算 gram 矩阵
    :param x: 特征图，shape：[1, width, height, channel]
    :return:
    '''
    b, w, h, ch = x.get_shape().as_list()
    # 这里求出来的是 每一个feature map之间的相似度
    features = tf.reshape(x, [b, h * w, ch]) # 将二三维的维度合并，已组成三维
    # 相似度矩阵 方法： 将矩阵转置为[ch, b*w], 再乘原矩阵，最后的矩阵是[ch , ch]
    # 防止矩阵数值过大，除以一个常数
    gram = tf.matmul(features, features, adjoint_a = True) / tf.constant(ch * w * h, tf.float32) # 参数3， 表示将第一个参数转置
    return gram


# 生成一个图像，均值为127.5，方差为20
result = initial_result((1, 640, 1162, 3), 127.5, 20)

# 读取 内容图像 和 风格图像
content_val = read_img(content_img_path)
style_val = read_img(style_img_path)

content = tf.placeholder(tf.float32, shape = [1, 640, 1162, 3])
style = tf.placeholder(tf.float32, shape = [1, 1024, 1024, 3])

# 载入模型， 注意：在python3中，需要添加一句： encoding='latin1'
data_dict = np.load(vgg16_npy_pyth, encoding='latin1').item()


# 创建这三张图像的 vgg 对象
vgg_for_content = VGGNet(data_dict)
vgg_for_style = VGGNet(data_dict)
vgg_for_result = VGGNet(data_dict)

# 创建 每个 神经网络
vgg_for_content.build(content)
vgg_for_style.build(style)
vgg_for_result.build(result)

# 提取哪些层特征
# 需要注意的是：内容特征抽取的层数和结果特征抽取的层数必须相同
# 风格特征抽取的层数和结果特征抽取的层数必须相同
content_features = [vgg_for_content.conv1_2,
                    vgg_for_content.conv2_2,
                    # vgg_for_content.conv3_3,
                    # vgg_for_content.conv4_3,
                    # vgg_for_content.conv5_3,
                    ]

result_content_features = [vgg_for_result.conv1_2,
                          vgg_for_result.conv2_2,
                          # vgg_for_result.conv3_3,
                          # vgg_for_result.conv4_3,
                          # vgg_for_result.conv5_3
                          ]

# feature_size, [1, width, height, channel]
style_features = [# vgg_for_style.conv1_2,
                          # vgg_for_style.conv2_2,
                          # vgg_for_style.conv3_3,
                          vgg_for_style.conv4_3,
                          # vgg_for_style.conv5_3
                          ]

# 为列表中每一个元素，都计算 gram
style_gram = [gram_matrix(feature) for feature in style_features]

result_style_features = [# vgg_for_result.conv1_2,
                          # vgg_for_result.conv2_2,
                          # vgg_for_result.conv3_3,
                          vgg_for_result.conv4_3,
                          # vgg_for_result.conv5_3
                          ]

result_style_gram = [gram_matrix(feature) for feature in result_style_features]

content_loss = tf.zeros(1, tf.float32)
# 计算内容损失
# 卷积层的形状 shape:[1, width, height, channel], 需要在三个通道上做平均
for c, c_ in zip(content_features, result_content_features):
    content_loss += tf.reduce_mean((c - c_)**2, axis = [1, 2, 3])

# 风格内容损失

style_loss = tf.zeros(1, tf.float32)
for s, s_ in zip(style_gram, result_style_gram):
    # 因为在计算gram矩阵的时候，降低了一维，所以，只需要在[1, 2]两个维度求均值即可
    style_loss += tf.reduce_mean( (s - s_)** 2, [1, 2] )


# 总的损失函数
loss = content_loss * lambda_c + style_loss * lambda_s


train_op = tf.train.AdamOptimizer( learning_rate ).minimize(loss)


init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for step in range(num_steps):
        loss_value, content_loss_value, style_loss_value, _ = \
            sess.run([loss, content_loss, style_loss, train_op],
                     feed_dict = {
                         content:content_val,
                         style:style_val
                     })
        # 因为loss_value等，是一个数组，需要通过索引将值去出
        print('step: %d, loss_value: %8.4f, content_loss: %8.4f, style_loss: %8.4f' % (step+1,
                                                                  loss_value[0],
                                                                  content_loss_value[0],
                                                                  style_loss_value[0]))
        result_img_path = os.path.join(output_dir, 'result_%05d.jpg'%(step+1))
        result_val = result.eval(sess)[0] # 将图像取出，因为之前是4维，所以需要使用一个索引0，将其取出

        result_val = np.clip(result_val, 0, 255)
        # np.clip() numpy.clip(a, a_min, a_max, out=None)[source]
        # 其中a是一个数组，后面两个参数分别表示最小和最大值

        img_arr = np.asarray(result_val, np.uint8)
        img = Image.fromarray(img_arr)
        img.save(result_img_path)
